package com.zhanghe.study;

import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.fitting.SimpleCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.apache.commons.math3.stat.regression.RegressionResults;
import org.apache.commons.math3.stat.regression.SimpleRegression;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

/**
 * @author zh
 * @date 2023/6/14 17:08
 */
public class Liner {

    public static void main(String[] args) {
        // TA 3+UV -> 3+UV
//        double[][] data = {
//                {	179557	,	161141		}	,
//                {	361486	,	350087		}	,
//                {	726527	,	567600		}	,
//                {	970678	,	681703		}	,
//                {	1326831	,	950184		}	,
//                {	1850689	,	1392760		}	,
//                {	2205657	,	1713404		}	,
//                {	51556	,	48677		}	,
//                {	316448	,	311383		}	,
//                {	652452	,	600481		}	,
//                {	26667	,	24870		}	,
//                {	334819	,	387935		}	,
//                {	708749	,	834852		}	,
//                {	853344	,	1008909		}	,
//                {	1816911	,	2117206		}	,
//                {	253632	,	260389		}	,
//                {	412449	,	427012		}	,
//                {	587266	,	612709		}	,
//                {	813639	,	888344		}	,
//                {	1026383	,	1152681		}	,
//                {	1466092	,	1651871		}	,
//                {	1633798	,	1881235		}	,
//                {	210373	,	152505		}	,
//                {	402968	,	347683		}	,
//                {	502524	,	436943		}	,
//                {	585488	,	532062		}	,
//                {	852751	,	865074		}	,
//                {	1106384	,	1203872		}	,
//                {	1382535	,	1541069		}	,
//                {	1816911	,	2112498		}
//
//        };

        double[][] data = {
//                {	0	,	0	}	,
//                {	26667	,	24890	}	,
                {334819, 387951},
                {708749, 834894},
                {853344, 1008967},
                {1816911, 2117220},
                {253632, 225778},
                {412449, 380415},
                {587266, 562633},
                {813639, 837258},
                {1026383, 1100257},
                {1466092, 1577641},
                {1633798, 1794237},
                {210373, 169358},
                {402968, 373542},
                {502524, 485219},
                {585488, 583774},
                {852751, 919957},
                {1106384, 1262990},
                {1382535, 1605248},
                {1816911, 2169666}

        };


        // 2144942	2143437
        //2702372	2700734
        //3199030	3197250
        //3798707	3796720
        //4381469	4379272
        //5970588	5969517
        //6544195	6543375
//        double[][] data = {{2144942,2143437},{2702372,2700734},{3199030,3197250},{3798707,3796720},
//                {4381469,4379272},{5970588,5969517},{6544195,6543375}};
        linearFit(data);
//        curveFit(data);
//        customizeFuncFit(data);
    }


    public static void linearFit(double[][] data) {
        // f(x)=kx+b
        List<double[]> fitData = new ArrayList<>();
        SimpleRegression regression = new SimpleRegression();
        regression.addData(data); // 数据集
        /*
         * RegressionResults 中是拟合的结果
         * 其中重要的几个参数如下：
         *   parameters:
         *      0: b
         *      1: k
         *   globalFitInfo
         *      0: 平方误差之和, SSE
         *      1: 平方和, SST
         *      2: R 平方, RSQ
         *      3: 均方误差, MSE
         *      4: 调整后的 R 平方, adjRSQ
         *
         * */
        RegressionResults results = regression.regress();

        double b = results.getParameterEstimate(0);
        double k = results.getParameterEstimate(1);
        double r2 = results.getRSquared();

        double intercept = regression.getIntercept();
        double slope = regression.getSlope();
        System.out.println("intercept:" + intercept + "--" + b);
        System.out.println("slope:" + slope + "--" + k);
        double correlation = Mutil.correlation(data);
        double adjR2 = results.getAdjustedRSquared();
        System.out.println("r2:" + r2+"---"+correlation+"--"+adjR2);




        // 总离差平方和
        double tss = results.getTotalSumSquares();
        // 重新计算生成拟合曲线
        for (double[] datum : data) {
            double[] xy = {datum[0], k * datum[0] + b};
            fitData.add(xy);
        }

        StringBuilder func = new StringBuilder();
        func.append("f(x) =");
        func.append(b >= 0 ? " " : " - ");
        func.append(Math.abs(b));
        func.append(k > 0 ? " + " : " - ");
        func.append(Math.abs(k));
        func.append("x");
        System.out.println(func.toString());

        fitData.forEach(
                ds ->
                        System.out.println(ds[0] + "," + ds[1])
        );

    }


    public static void curveFit(double[][] data) {
        ParametricUnivariateFunction function = new PolynomialFunction.Parametric();/*多项式函数*/
        double[] guess = {0, 0, 0, 0}; /*猜测值 依次为 常数项、1次项、二次项*/

        // 初始化拟合
        SimpleCurveFitter curveFitter = SimpleCurveFitter.create(function, guess);

        // 添加数据点
        WeightedObservedPoints observedPoints = new WeightedObservedPoints();
        for (double[] point : data) {
            observedPoints.add(point[0], point[1]);
        }
        /*
         * best 为拟合结果
         * 依次为 常数项、1次项、二次项
         * 对应 y = a + bx + cx^2 中的 a, b, c
         * */
        double[] best = curveFitter.fit(observedPoints.toList());

        /*
         * 根据拟合结果重新计算
         * */
        List<double[]> fitData = new ArrayList<>();
        for (double[] datum : data) {
            double x = datum[0];
            double y = best[0] + best[1] * x + best[2] * x * x; // y = a + bx + cx^2
            double[] xy = {x, y};
            fitData.add(xy);
        }


        System.out.println(best[0]);
        System.out.println(best[1]);
        System.out.println(best[2]);

        StringBuilder func = new StringBuilder();
        func.append("f(x) =");
        func.append(best[0] > 0 ? " " : " - ");

        func.append(BigDecimal.valueOf(Math.abs(best[0])).setScale(3, BigDecimal.ROUND_HALF_UP).doubleValue());
        func.append(best[1] > 0 ? " + " : " - ");
        func.append(BigDecimal.valueOf(Math.abs(best[1])).setScale(3, BigDecimal.ROUND_HALF_UP).doubleValue());
        func.append("x");
        func.append(best[2] > 0 ? " + " : " - ");
        func.append(BigDecimal.valueOf(Math.abs(best[2])).setScale(10, BigDecimal.ROUND_HALF_UP).doubleValue());
        func.append("x^2");

        System.out.println(func.toString());

        fitData.forEach(
                ds ->
                        System.out.println(ds[0] + "," + ds[1])
        );


    }

    static class MyFunction implements ParametricUnivariateFunction {
        public double value(double x, double... parameters) {
            double a = parameters[0];
            double b = parameters[1];
            double c = parameters[2];
            double d = parameters[3];
            // d + ((a - d) / (1 + Math.pow(x / c, b)))
            return d + (a / (Math.pow(x / c, b)));
        }

        public double[] gradient(double x, double... parameters) {
            double a = parameters[0];
            double b = parameters[1];
            double c = parameters[2];
            double d = parameters[3];

            double[] gradients = new double[4];
            double den = Math.pow(x / c, b);

            gradients[0] = 1 / den; // 对 a 求导

            gradients[1] = ((a) * Math.pow(x / c, b) * Math.log(x / c)) / (den * den); // 对 b 求导

            gradients[2] = (b * Math.pow(x / c, b - 1) * (x / (c * c)) * (a - d)) / (den * den); // 对 c 求导

            gradients[3] = 1 - (1 / den); // 对 d 求导

            return gradients;

        }
    }

    public static void customizeFuncFit(double[][] scatters) {
        ParametricUnivariateFunction function = new MyFunction();/*多项式函数*/
        double[] guess = {1, 1, 1, 1}; /*猜测值 依次为 a b c d 。必须和 gradient 方法返回数组对应。如果不知道都设置为 1*/

        // 初始化拟合
        SimpleCurveFitter curveFitter = SimpleCurveFitter.create(function, guess);

        // 添加数据点
        WeightedObservedPoints observedPoints = new WeightedObservedPoints();
        for (double[] point : scatters) {
            observedPoints.add(point[0], point[1]);
        }

        /*
         * best 为拟合结果 对应 a b c d
         * 可能会出现无法拟合的情况
         * 需要合理设置初始值
         * */
        double[] best = curveFitter.fit(observedPoints.toList());
        double a = best[0];
        double b = best[1];
        double c = best[2];
        double d = best[3];

        // 根据拟合结果生成拟合曲线散点
        List<double[]> fitData = new ArrayList<>();
        for (double[] datum : scatters) {
            double x = datum[0];
            double y = function.value(x, a, b, c, d);
            double[] xy = {x, y};
            fitData.add(xy);
        }

        // f(x) = d + ((a - d) / (1 + Math.pow(x / c, b)))
        StringBuilder func = new StringBuilder();
        func.append("f(x) =");
        func.append(d > 0 ? " " : " - ");
        func.append(Math.abs(d));
        func.append(" ((");
        func.append(a > 0 ? "" : "-");
        func.append(Math.abs(a));
        func.append(d > 0 ? " - " : " + ");
        func.append(Math.abs(d));
        func.append(" / (1 + ");
        func.append("(x / ");
        func.append(c > 0 ? "" : " - ");
        func.append(Math.abs(c));
        func.append(") ^ ");
        func.append(b > 0 ? " " : " - ");
        func.append(Math.abs(b));

        System.out.println(func.toString());

        fitData.forEach(
                ds ->
                        System.out.println(ds[0] + "," + ds[1])
        );

    }


}
